# coding:utf-8
import sys

import baostock as bs
import pandas as pd
import datetime
from datetime import date
from src.manager.log_manager import LogManager
from src.util.properties_util import PropertiesUtil
from src.manager.oracle_manager import OracleManager
from src.config.oracle_config import OracleConfig

Logger = LogManager.get_logger(__name__)


class CollectStockHandler:
    """
    调用baostock的API获取股票数据的handler
    """

    def __init__(self):
        """
        构造函数，初始化oracle_manager和cursor对象
        """
        super().__init__()
        self.oracle_manager = OracleManager(OracleConfig.Username, OracleConfig.Password, OracleConfig.Url)
        self.oracle_manager.connect()
        self.cursor = self.oracle_manager.get_cursor()

    def collect(self, stock_code, begin_date, end_date, return_or_store=1):
        """
        根据参数，调用baostock的API获取股票数据，或者存储在csv文件中，或者返回
        return_or_store：1表示将股票记录作为返回值；2表示将股票记录存储在csv文件中
        """
        # 登录系统
        lg = bs.login()

        # 获取沪深A股历史K线数据
        # 详细指标参数，参见“历史行情指标参数”章节；“分钟线”参数与“日线”参数不同。“分钟线”不包含指数。
        # 分钟线指标：date,time,code,open,high,low,close,volume,amount,adjustflag
        # 周月线指标：date,code,open,high,low,close,volume,amount,adjustflag,turn,pctChg
        rs = bs.query_history_k_data_plus(stock_code,
                                          "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,"
                                          "tradestatus,pctChg,isST",
                                          start_date=begin_date, end_date=end_date,
                                          frequency="d", adjustflag="3")

        # 打印结果集
        data_list = []
        while (rs.error_code == '0') & rs.next():
            # 获取一条记录，将记录合并在一起
            data_list.append(rs.get_row_data())

        # 将股票记录作为返回值
        if return_or_store == 1:
            # 登出系统
            bs.logout()
            return data_list

        # 将股票记录存储在csv文件中
        if return_or_store == 2:
            result = pd.DataFrame(data_list, columns=rs.fields)
            # 结果集输出到csv文件
            result.to_csv("/temp2/history_A_stock_k_data.csv", index=False)
            # 登出系统
            bs.logout()

    def insert_stock_transaction_data(self):
        """
        根据begin_date和end_date，调用baostock的API获取股票数据，并存储到stock_transaction_data表中
        """

        # 登录系统
        lg = bs.login()

        Logger.info('获取stock_info表中的所有记录')

        begin_date = PropertiesUtil.get_value(
            "\\mywork\\gitcode-repository\\hades\\dev-project\\adam\\src\\main\\resources\\stock-record.properties",
            "stockRecord.begin.date")
        end_date = PropertiesUtil.get_value(
            "\\mywork\\gitcode-repository\\hades\\dev-project\\adam\\src\\main\\resources\\stock-record.properties",
            "stockRecord.end.date")
        Logger.info('采集股票数据的开始时间为【' + begin_date + '】，结束时间为【' + end_date + '】')
        begin_date = datetime.datetime.strptime(begin_date, '%Y%m%d').strftime('%Y-%m-%d')
        end_date = datetime.datetime.strptime(end_date, '%Y%m%d').strftime('%Y-%m-%d')

        self.cursor.execute("select * from stock_info t")
        # self.cursor.execute("select * from stock_info t where t.code_ not in(select distinct t2.code_ from stock_transaction_data t2)")
        stock_info_list = self.cursor.fetchall()

        if stock_info_list is not None or len(stock_info_list) > 0:
            # 用于记录插入的记录数
            insert_stock_number = 0
            # 读取并保存股票记录
            for stock_info in stock_info_list:
                rs = bs.query_history_k_data_plus(stock_info[4],
                                                  "date,code,open,high,low,close,preclose,volume,amount,pctChg,turn,"
                                                  "tradestatus",
                                                  start_date=begin_date, end_date=end_date,
                                                  frequency="d", adjustflag="3")

                # 获取一条记录，将记录合并在一起
                param_list = []
                while (rs.error_code == '0') & rs.next():
                    row_data = rs.get_row_data()
                    # 判断是否是停牌，如果停牌则跳过
                    if int(row_data[11]) == 0:
                        Logger.info('股票【' + row_data[1][3:] + '】在日期【' + begin_date + '】至【' + end_date + '】停牌，因此跳过')
                        continue

                    # 判断是否没有数据
                    if row_data[2] == '':
                        Logger.info('股票【' + row_data[1] + '】在日期【' + row_data[0] + '】没有数据，因此跳过')
                        continue

                    if row_data[9] == '':
                        Logger.error(row_data)

                    up_down = None
                    if row_data[9] != '' and float(row_data[9]) > 0:
                        up_down = 1
                    if row_data[9] != '' and float(row_data[9]) < 0:
                        up_down = -1
                    if row_data[9] == '' or float(row_data[9]) == 0:
                        up_down = 0
                    param_list.append(
                        (date.fromisoformat(row_data[0]), row_data[1][3:], float(row_data[2]), float(row_data[3]),
                         float(row_data[4]), float(row_data[5]), float(row_data[6]), int(row_data[7] if row_data[7] != '' else 0),
                         float(row_data[8] if row_data[8] != '' else 0), round(float(row_data[9] if row_data[9] != '' else 0), 2),
                         round(float(row_data[5]) - float(row_data[6]), 2), float(row_data[10] if row_data[10] != '' else 0), up_down))
                    insert_stock_number = insert_stock_number + 1
                self.oracle_manager.batch_insert("insert into stock_transaction_data(date_, code_, open_price, "
                                                 "highest_price, lowest_price, close_price, last_close_price, volume, "
                                                 "turnover, change_range, change_amount, turnover_rate, up_down) "
                                                 "values(:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11, :12, :13)",
                                                 param_list)
        else:
            Logger.info('从stock_info表中没有找到任何记录')

        Logger.info('插入的股票记录数为【' + str(insert_stock_number) + '】')

        self.oracle_manager.cursor_close()
        self.oracle_manager.connect_close()

        # 登出系统
        bs.logout()
        sys.exit(0)
